package locate

import (
	_ "embed"
	"encoding/json"
	"ewdetect/calc"
	"ewdetect/config"
	"fmt"
	"image/color"
	"math"
	"sort"

	sm "github.com/flopp/go-staticmaps"
	"github.com/fogleman/gg"
	"github.com/golang/geo/s2"
	"github.com/rs/zerolog/log"
	"gonum.org/v1/gonum/optimize"
	"gonum.org/v1/plot/palette/moreland"
)

// ComputedTableStruct holds the precomputed P- and S-wave tables decoded from
// the embedded assets/computed_tables.json by Init.
//
// Angles and Depths are the two table axes; Angles is stored in degrees in the
// JSON asset and converted to radians by Init. PTable and STable are indexed as
// Table[depthIndex][angleIndex]. Their values are compared against arrival
// delays in seconds by ErrorFunction, so they are presumably travel times in
// seconds — TODO confirm against the table generator.
type ComputedTableStruct struct {
	Angles, Depths []float64
	PTable         [][]float64 `json:"p_table"` // PTable[depth][angle]
	STable         [][]float64 `json:"s_table"` // STable[depth][angle]
}

// computedTableJson is the raw JSON of the precomputed P/S tables, embedded
// into the binary at build time and decoded into ComputedTable by Init.
//
//go:embed assets/computed_tables.json
var computedTableJson []byte

// ComputedTable holds the decoded tables read by Interpolate2D,
// InterpolatePWave and InterpolateSWave. It stays empty until Init runs;
// Init also converts its Angles axis from degrees to radians.
var ComputedTable ComputedTableStruct

// Init decodes the embedded P/S tables into ComputedTable and converts their
// angle axis from degrees to radians, the angular unit used by the rest of the
// package (GreatCircleAngle, Observation, Guess). It must run before any
// interpolation, because the interpolation helpers index ComputedTable's axes.
//
// If decoding fails, the error is logged and Init returns immediately. It
// neither converts the possibly partially decoded data nor logs the success
// message. ComputedTable is left in whatever state json.Unmarshal produced.
func Init() {
	if err := json.Unmarshal(computedTableJson, &ComputedTable); err != nil {
		log.Error().Err(err).Msg("locate: failed to unmarshal computed tables")
		return
	}
	for i, deg := range ComputedTable.Angles {
		ComputedTable.Angles[i] = deg / 180 * math.Pi // to rad
	}
	log.Info().Msg("locate: initialized computed tables")
}

// Interpolate2D bilinearly interpolates table at (depth, angle). The table is
// indexed as table[depthIndex][angleIndex] over the grid
// ComputedTable.Depths × ComputedTable.Angles. Coordinates outside the grid are
// clamped to its edges by findBoundingIndices, so nothing is extrapolated.
// After Init has run, angle is expected in radians.
func Interpolate2D(depth, angle float64, table [][]float64) float64 {
	d0, d1, wd := findBoundingIndices(ComputedTable.Depths, depth)
	a0, a1, wa := findBoundingIndices(ComputedTable.Angles, angle)

	// lerp blends lo and hi, giving weight w to hi and 1-w to lo.
	lerp := func(lo, hi, w float64) float64 { return (1-w)*lo + w*hi }

	// First interpolate along the angle axis on both bracketing depth rows,
	// then blend the two rows along the depth axis.
	nearRow, farRow := table[d0], table[d1]
	alongNear := lerp(nearRow[a0], nearRow[a1], wa)
	alongFar := lerp(farRow[a0], farRow[a1], wa)
	return lerp(alongNear, alongFar, wd)
}

// InterpolatePWave performs bilinear interpolation on the P-wave table.
func InterpolatePWave(depth, angle float64) float64 {
	return Interpolate2D(depth, angle, ComputedTable.PTable)
}

// InterpolateSWave performs bilinear interpolation on the S-wave table.
func InterpolateSWave(depth, angle float64) float64 {
	return Interpolate2D(depth, angle, ComputedTable.STable)
}

// findBoundingIndices finds, in the ascending slice arr, the adjacent indices
// lowIdx and highIdx that bracket val, using binary search. It also returns
// the linear-interpolation weight frac. For val inside the table,
// val == (1-frac)*arr[lowIdx] + frac*arr[highIdx] up to rounding.
//
// Values outside the table are clamped rather than extrapolated:
//   - val <= arr[0] yields (0, 1, 0);
//   - val >= arr[n-1] yields (n-2, n-1, 1);
//   - a single-element arr yields (0, 0, 0), so the lone sample is used as is.
//
// arr must be sorted in ascending order. An empty arr is a programming error
// and panics.
func findBoundingIndices(arr []float64, val float64) (lowIdx, highIdx int, frac float64) {
	n := len(arr)
	switch {
	case n == 0:
		panic("locate: findBoundingIndices called with an empty table axis")
	case n == 1:
		// Nothing to interpolate between; indices 0,1 would overflow the table.
		return 0, 0, 0
	case val <= arr[0]:
		return 0, 1, 0 // clamp below the table
	case val >= arr[n-1]:
		return n - 2, n - 1, 1 // clamp above the table
	}

	// Smallest index with arr[idx] >= val. Here arr[0] < val < arr[n-1], so
	// idx is in [1, n-1] and arr[idx-1] < val <= arr[idx]; the denominator
	// below is therefore strictly positive.
	idx := sort.Search(n, func(i int) bool { return arr[i] >= val })
	if idx >= n {
		// Only reachable when val is NaN (every comparison is false): fall
		// back to the last interval so the returned indices stay valid.
		idx = n - 1
	}

	lowIdx, highIdx = idx-1, idx
	frac = (val - arr[lowIdx]) / (arr[highIdx] - arr[lowIdx])
	return lowIdx, highIdx, frac
}

// Observation is one station's record of an event: where the station sits
// and when it saw the P and S arrivals.
type Observation struct {
	StationName string

	// Station coordinates, in radians.
	Lat, Lon float64

	// Arrival times, in seconds from the event start epoch.
	PWaveArrival, SWaveArrival float64
}

// Guess is a candidate hypocentre estimate, as explored by the optimizer and
// returned by NelderMeadOptimization.
type Guess struct {
	// Epicentre coordinates, in radians.
	Lat, Lon float64

	// Origin time in seconds relative to the event start epoch. It is normally
	// negative, because the rupture precedes the arrivals that define the
	// event start.
	Epoch float64

	// Depth is the depth coordinate passed to the table interpolation, so it
	// shares the units of ComputedTable.Depths (units — TODO confirm).
	Depth float64
}

// Square returns f multiplied by itself.
func Square(f float64) (sq float64) {
	sq = f * f
	return sq
}

// GreatCircleAngle returns the central angle, in radians, between two points
// given by latitude/longitude in radians, using the spherical law of cosines.
//
// The cosine is clamped to [-1, 1] before math.Acos. For identical or
// antipodal points, floating-point rounding can push it just outside that
// range, and math.Acos would then return NaN. That NaN would poison
// ErrorFunction, e.g. when the optimizer's start point coincides with a
// station.
func GreatCircleAngle(lat1, lon1, lat2, lon2 float64) float64 { // rad
	cosAngle := math.Sin(lat1)*math.Sin(lat2) + math.Cos(lat1)*math.Cos(lat2)*math.Cos(lon1-lon2)
	return math.Acos(math.Max(-1, math.Min(1, cosAngle)))
}

// ErrorFunction scores how well guess explains the observed arrivals.
//
// For each observation it does the following:
//   - Computes the observed P and S delays (arrival time minus guess.Epoch).
//   - Compares each delay with the value interpolated from the tables at
//     guess.Depth and at the great-circle angle between guess and the station.
//   - Adds the squared relative residuals ((observed - predicted) / observed)²
//     to the total.
//
// The result is dimensionless, and lower means a better fit. An observed
// delay of exactly zero makes its term infinite or NaN.
func ErrorFunction(observations *[]Observation, guess Guess) float64 {
	total := 0.0
	for i := range *observations {
		obs := &(*observations)[i]
		angle := GreatCircleAngle(guess.Lat, guess.Lon, obs.Lat, obs.Lon)

		sObserved := obs.SWaveArrival - guess.Epoch
		pObserved := obs.PWaveArrival - guess.Epoch
		sResidual := (sObserved - InterpolateSWave(guess.Depth, angle)) / sObserved
		pResidual := (pObserved - InterpolatePWave(guess.Depth, angle)) / pObserved

		total += Square(sResidual) + Square(pResidual)
	}
	log.Debug().
		Float64("error_squared", total).
		Msg("locate: calculated error function")
	return total // no unit
}

// NelderMeadOptimization locates an event by minimising ErrorFunction over
// (latitude, longitude, origin epoch, depth) with gonum's Nelder-Mead simplex
// method, and returns the best Guess found.
//
// The search starts at the circular means of the station latitudes and
// longitudes (each averaged independently), with epoch 0 and depth 0. When
// both config.Debug and drawDebugMap are true, a heat map of the error
// function around the solution is also saved to
// debug/event-maps/<eventName>.png.
//
// A Guess whose fields are all NaN is returned, with the problem logged, in
// two cases: observations is empty, or the optimizer fails without producing
// any result. Callers can detect this with math.IsNaN.
//
// References:
// http://epubs.siam.org/doi/pdf/10.1137/S1052623496303470
// https://pkg.go.dev/gonum.org/v1/gonum/optimize#NelderMead
func NelderMeadOptimization(eventName string, observations *[]Observation, drawDebugMap bool) Guess {
	log.Info().
		Str("event_name", eventName).
		Int("observation_count", len(*observations)).
		Msg("locate: starting Nelder-Mead optimization")

	nan := math.NaN()
	noEstimate := Guess{Lat: nan, Lon: nan, Epoch: nan, Depth: nan}

	if len(*observations) == 0 {
		// No stations means no starting point (the mean would divide by zero)
		// and nothing to fit.
		log.Error().Str("event_name", eventName).Msg("locate: cannot locate event without observations")
		return noEstimate
	}

	problem := optimize.Problem{
		Func: func(x []float64) float64 {
			return ErrorFunction(observations, Guess{
				Lat: x[0], Lon: x[1], Epoch: x[2], Depth: x[3]})
		},
	}
	startLat, startLon := meanStationPosition(*observations)
	result, err := optimize.Minimize(problem, []float64{startLat, startLon, 0, 0}, nil, &optimize.NelderMead{})
	if err != nil {
		log.Error().Err(err).Msg("locate: optimization failed")
	}
	if result == nil {
		// Minimize may return a nil result alongside its error (e.g. when it
		// rejects the problem before iterating); dereferencing it would panic.
		return noEstimate
	}
	best := Guess{Lat: result.X[0], Lon: result.X[1], Epoch: result.X[2], Depth: result.X[3]}

	if config.Debug && drawDebugMap {
		renderLocateDebugMap(eventName, observations, best)
	}

	log.Info().
		Float64("lat", best.Lat).
		Float64("lon", best.Lon).
		Float64("epoch", best.Epoch).
		Float64("depth", best.Depth).
		Msg("locate: optimization complete")

	return best
}

// meanStationPosition returns the optimizer's starting point. Latitude and
// longitude are each the circular mean, in radians, of the station values,
// averaged independently through their unit vectors. observations must not be
// empty.
func meanStationPosition(observations []Observation) (lat, lon float64) {
	var sumLatX, sumLatY, sumLonX, sumLonY float64
	for _, observation := range observations {
		sumLatX += math.Cos(observation.Lat)
		sumLatY += math.Sin(observation.Lat)
		sumLonX += math.Cos(observation.Lon)
		sumLonY += math.Sin(observation.Lon)
	}
	n := float64(len(observations))
	return math.Atan2(sumLatY/n, sumLatX/n), math.Atan2(sumLonY/n, sumLonX/n)
}

// renderLocateDebugMap renders a static map of the padded box around best and
// the stations. The map is overlaid with a heat map of sqrt(ErrorFunction)
// sampled on a config.DebugFunctionAreas × config.DebugFunctionAreas grid at
// best's epoch and depth. It also carries a red marker for best and green
// markers for the stations. The map is saved to
// debug/event-maps/<eventName>.png. Failures are logged rather than returned,
// since the map is only a debugging aid.
func renderLocateDebugMap(eventName string, observations *[]Observation, best Guess) {
	log.Debug().Msg("locate: generating debug map")

	// Bounding box enclosing the solution and every station, then padded.
	minLat, maxLat := best.Lat, best.Lat
	minLon, maxLon := best.Lon, best.Lon
	for _, observation := range *observations {
		if observation.Lat > maxLat {
			maxLat = observation.Lat
		}
		if observation.Lon > maxLon {
			maxLon = observation.Lon
		}
		if observation.Lat < minLat {
			minLat = observation.Lat
		}
		if observation.Lon < minLon {
			minLon = observation.Lon
		}
	}
	minLat -= config.DebugBoundingBoxPadding
	minLon -= config.DebugBoundingBoxPadding
	maxLat += config.DebugBoundingBoxPadding
	maxLon += config.DebugBoundingBoxPadding

	// gridLat and gridLon map a grid-line index k in [0, DebugFunctionAreas]
	// onto the box, from the max edge (k = 0) to the min edge.
	gridLat := func(k int) float64 {
		w := float64(k) / float64(config.DebugFunctionAreas)
		return w*minLat + (1-w)*maxLat
	}
	gridLon := func(k int) float64 {
		w := float64(k) / float64(config.DebugFunctionAreas)
		return w*minLon + (1-w)*maxLon
	}
	toLatLng := func(lat, lon float64) s2.LatLng {
		return s2.LatLngFromDegrees(lat*180/math.Pi, lon*180/math.Pi)
	}

	// Sample sqrt(error) at the first grid corner of every cell, keeping the
	// range for colour normalisation.
	var evals [][]float64
	minEval := math.MaxFloat64
	maxEval := 0.0
	for i := 0; i < config.DebugFunctionAreas; i++ {
		evalLine := make([]float64, 0, config.DebugFunctionAreas)
		for j := 0; j < config.DebugFunctionAreas; j++ {
			eval := math.Sqrt(ErrorFunction(observations, Guess{
				Lat: gridLat(i), Lon: gridLon(j), Epoch: best.Epoch, Depth: best.Depth}))
			if eval > maxEval {
				maxEval = eval
			}
			if eval < minEval {
				minEval = eval
			}
			evalLine = append(evalLine, eval)
		}
		evals = append(evals, evalLine)
	}

	mapCtx := sm.NewContext()
	mapCtx.SetSize(int(config.DebugResolution*math.Abs(calc.AngleDiff(maxLon, minLon)/calc.AngleDiff(maxLat, minLat))), config.DebugResolution)
	mapCtx.SetBoundingBox(s2.RectFromLatLng(toLatLng(minLat, minLon)).AddPoint(toLatLng(maxLat, maxLon)))
	mapCtx.AddObject(sm.NewMarker(toLatLng(best.Lat, best.Lon), color.RGBA{0xff, 0, 0, 0xff}, 16.0))
	for _, observation := range *observations {
		mapCtx.AddObject(sm.NewMarker(toLatLng(observation.Lat, observation.Lon), color.RGBA{0, 0xff, 0, 0xff}, 16.0))
	}

	colormap := moreland.ExtendedBlackBody()
	colormap.SetMax(1)
	colormap.SetMin(0)
	colormap.SetAlpha(0.5)

	// Log-compress evaluations into [0, 1] so the basin around the minimum
	// stays distinguishable next to much larger values far away from it.
	logMin := math.Log(minEval/maxEval + math.E)
	logMax := math.Log(1 + math.E)
	for i := 0; i < config.DebugFunctionAreas; i++ {
		for j := 0; j < config.DebugFunctionAreas; j++ {
			remappedEval := (math.Log(evals[i][j]/maxEval+math.E) - logMin) / (logMax - logMin)
			currColor, err := colormap.At(remappedEval)
			if err != nil {
				log.Warn().Err(err).Msg("locate: failed to get color from colormap")
				currColor = color.RGBA{0, 0, 0, 0}
			}
			latLow, latHi := gridLat(i), gridLat(i+1)
			lonLow, lonHi := gridLon(j), gridLon(j+1)
			mapCtx.AddObject(
				sm.NewArea(
					[]s2.LatLng{
						toLatLng(latLow, lonLow),
						toLatLng(latHi, lonLow),
						toLatLng(latHi, lonHi),
						toLatLng(latLow, lonHi),
					},
					color.RGBA{0, 0, 0, 0},
					currColor,
					1,
				),
			)
		}
	}

	mapCtx.OverrideAttribution(fmt.Sprintf("%s - EWDetect - Louis \"OnTake\" Dalibard - 2025.", mapCtx.Attribution()))

	img, err := mapCtx.Render()
	if err != nil {
		// No usable image: do not hand it to gg.SavePNG.
		log.Error().Err(err).Msg("locate: failed to render debug map")
		return
	}
	if err = gg.SavePNG("debug/event-maps/"+eventName+".png", img); err != nil {
		log.Error().Err(err).Msg("locate: failed to save debug map")
		return
	}
	log.Debug().Str("event_name", eventName).Msg("locate: debug map generated")
}
